{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "98a1786c",
   "metadata": {},
   "source": [
    "# Continue Analysis on Generated Charts\n",
    "\n",
    "While TableGPT2 excels at data analysis tasks, it currently has no built-in support for visual modalities. Because many data analysis tasks involve visualization, we provide an interface for plugging in your own Visual Language Model (VLM).\n",
    "\n",
    "When the agent performs a visualization task (typically by calling `matplotlib.pyplot.show`), the VLM takes over from the LLM and summarizes the resulting figure. This avoids two common failure modes of text-only LLMs in visualization tasks: replying with nothing more than \"I have plotted the data,\" or hallucinating the content of the plot.\n",
    "\n",
    "We continue using the agent from the previous section to perform a data visualization task and observe its final output.\n",
    "\n",
    "> **NOTE** Before you start, you can install Chinese fonts using the following command:\n",
    "\n",
    "```bash\n",
    "apt-get update && apt-get install -y --no-install-recommends fonts-noto-cjk\n",
    "mplfonts init\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "15aba93a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import date\n",
    "from typing import TypedDict\n",
    "\n",
    "from langchain_core.messages import HumanMessage\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "from pybox import AsyncLocalPyBoxManager\n",
    "from tablegpt import DEFAULT_TABLEGPT_IPYKERNEL_PROFILE_DIR\n",
    "from tablegpt.agent import create_tablegpt_graph\n",
    "from tablegpt.agent.file_reading import Stage\n",
    "\n",
    "llm = ChatOpenAI(openai_api_base=\"YOUR_VLLM_URL\", openai_api_key=\"whatever\", model_name=\"TableGPT2-7B\")\n",
    "pybox_manager = AsyncLocalPyBoxManager(profile_dir=DEFAULT_TABLEGPT_IPYKERNEL_PROFILE_DIR)\n",
    "checkpointer = MemorySaver()\n",
    "\n",
    "agent = create_tablegpt_graph(\n",
    "    llm=llm,\n",
    "    pybox_manager=pybox_manager,\n",
    "    checkpointer=checkpointer,\n",
    "    session_id=\"some-session-id\", # This is required when using file-reading\n",
    ")\n",
    "\n",
    "class Attachment(TypedDict):\n",
    "    \"\"\"Contains at least one dictionary with the key filename.\"\"\"\n",
    "    filename: str\n",
    "\n",
    "attachment_msg = HumanMessage(\n",
    "    content=\"\",\n",
    "    # Make sure the IPython kernel can access this file.\n",
    "    additional_kwargs={\"attachments\": [Attachment(filename=\"titanic.csv\")]},\n",
    ")\n",
    "\n",
    "# Reading and processing files.\n",
    "response = await agent.ainvoke(\n",
    "    input={\n",
    "        \"entry_message\": attachment_msg,\n",
    "        \"processing_stage\": Stage.UPLOADED,\n",
    "        \"messages\": [attachment_msg],\n",
    "        \"parent_id\": \"some-parent-id1\",\n",
    "        \"date\": date.today(),\n",
    "    },\n",
    "    config={\n",
    "        # Using checkpointer requires binding thread_id at runtime.\n",
    "        \"configurable\": {\"thread_id\": \"some-thread-id\"},\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0afbab13",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "content=\"好的，我将基于性别绘制一个饼图，以展示每个性别的人数。首先，我们需要统计每个性别的人数，然后使用 `seaborn` 和 `matplotlib` 来绘制饼图。\\n\\n```python\\nimport seaborn as sns\\nimport matplotlib.pyplot as plt\\n\\n# Count the number of people for each gender\\ngender_counts = df['Sex'].value_counts()\\n\\n# Create a pie chart\\nplt.figure(figsize=(8, 6))\\nplt.pie(gender_counts, labels=gender_counts.index, autopct='%1.1f%%', startangle=140, colors=sns.color_palette('pastel'))\\nplt.title('Gender Distribution')\\nplt.show()\\n```\" additional_kwargs={} response_metadata={'finish_reason': 'stop', 'model_name': 'TableGPT2-7B'} id='run-6115fe22-3b55-4d85-be09-6c31a59736f6'\n",
      "content=[{'type': 'text', 'text': '```pycon\\n<Figure size 800x600 with 1 Axes>\\n```'}, {'type': 'image_url', 'image_url': {'url': '...'}}] name='python' id='226ba8f2-29a7-4706-9178-8cb5b4062488' tool_call_id='03eb1113-6aed-4e0a-a3c0-4cc0043a55ee' artifact=[]\n",
      "content='饼图已经成功生成。' additional_kwargs={} response_metadata={'finish_reason': 'stop', 'model_name': 'TableGPT2-7B'} id='run-83468bd1-9451-4c78-91a3-b0f96ffa169a'\n"
     ]
    }
   ],
   "source": [
    "# Define the human message that asks the model to draw a pie chart based on gender data\n",
    "human_message = HumanMessage(content=\"Draw a pie chart based on gender and the number of people of each gender.\")\n",
    "\n",
    "async for event in agent.astream_events(\n",
    "    input={\n",
    "        \"messages\": [human_message],\n",
    "        \"parent_id\": \"some-parent-id2\",\n",
    "        \"date\": date.today(),\n",
    "    },\n",
    "    version=\"v2\",\n",
    "    # Reuse the same thread_id so the checkpointer restores the memory of the previous run.\n",
    "    config={\"configurable\": {\"thread_id\": \"some-thread-id\"}},\n",
    "):\n",
    "    evt = event[\"event\"]\n",
    "    if evt == \"on_chat_model_end\":\n",
    "        print(event[\"data\"][\"output\"])\n",
    "    elif event[\"name\"] == \"tool_node\" and evt == \"on_chain_stream\":\n",
    "        for lc_msg in event[\"data\"][\"chunk\"][\"messages\"]:\n",
    "            print(lc_msg)\n",
    "    else:\n",
    "        # Handle other events here\n",
    "        pass"
   ]
  },
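  {
   "cell_type": "markdown",
   "id": "b7e3a9d1",
   "metadata": {},
   "source": [
    "The tool message above carries the rendered figure as an `image_url` content part, which a text-only LLM cannot inspect. The following is a minimal sketch of how such a message could be routed to a VLM. It is an illustration only, not TableGPT's actual implementation: `contains_image` and `choose_model` are hypothetical helpers, and the content shape is assumed from the printed output.\n",
    "\n",
    "```python\n",
    "def contains_image(content):\n",
    "    # Hypothetical helper: True if the message content has an image part.\n",
    "    if isinstance(content, str):\n",
    "        return False\n",
    "    return any(\n",
    "        isinstance(part, dict) and part.get('type') == 'image_url'\n",
    "        for part in content\n",
    "    )\n",
    "\n",
    "\n",
    "def choose_model(content):\n",
    "    # Hypothetical routing rule: image output goes to the VLM, text to the LLM.\n",
    "    return 'vlm' if contains_image(content) else 'llm'\n",
    "\n",
    "\n",
    "# Content shaped like the tool message printed above (URL elided).\n",
    "tool_content = [\n",
    "    {'type': 'text', 'text': '<Figure size 800x600 with 1 Axes>'},\n",
    "    {'type': 'image_url', 'image_url': {'url': '...'}},\n",
    "]\n",
    "\n",
    "print(choose_model(tool_content))         # prints: vlm\n",
    "print(choose_model('plain text result'))  # prints: llm\n",
    "```\n",
    "\n",
    "In the real agent this decision is made for you once a `vlm` is passed to `create_tablegpt_graph`."
   ]
  },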
  {
   "cell_type": "markdown",
   "id": "1c428aca",
   "metadata": {},
   "source": [
    "Now let's set up the Visual Language Model (VLM) and create a new agent with VLM support:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "425633b7-14a4-4bbc-91e1-d94161a41682",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the VLM instance\n",
    "vlm = ChatOpenAI(openai_api_base=\"YOUR_VLM_URL\", openai_api_key=\"whatever\", model_name=\"YOUR_MODEL_NAME\")\n",
    "\n",
    "# Reuse the llm, pybox_manager, and checkpointer defined above\n",
    "agent_with_vlm = create_tablegpt_graph(\n",
    "    llm=llm,\n",
    "    pybox_manager=pybox_manager,\n",
    "    vlm=vlm,\n",
    "    checkpointer=checkpointer,\n",
    "    session_id=\"some-session-id\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40a19cb4-adbc-49de-90af-4d43e77d4308",
   "metadata": {},
   "source": [
    "We use LangGraph's [time travel](https://langchain-ai.github.io/langgraph/tutorials/introduction/#part-7-time-travel) feature to rewind to the checkpoint just before the agent answered the last question, so that the earlier answer held in memory does not influence the new run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3652d131-6ed7-4d75-bfe2-152ba40fb090",
   "metadata": {},
   "outputs": [],
   "source": [
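    "# The history is expected newest first; it is reversed below to walk oldest -> newest.\n",
    "state_history = agent.get_state_history(config={\"configurable\": {\"thread_id\": \"some-thread-id\"}})\n",
    "\n",
    "to_replay = None\n",
    "for state in list(state_history)[::-1]:\n",
    "    # Keep the most recent checkpoint whose next node is \"__start__\", i.e. the\n",
    "    # state just before the pie chart question was processed.\n",
    "    if state.next and state.next[0] == \"__start__\":\n",
    "        to_replay = state"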
   ]
  },
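  {
   "cell_type": "markdown",
   "id": "c4f8d2a6",
   "metadata": {},
   "source": [
    "The selection loop above keeps overwriting `to_replay`, so the last match wins. The sketch below replays that logic on plain stand-in objects to show which checkpoint is picked. `Snapshot` and `pick_replay_point` are hypothetical names used only for illustration, and the example assumes the history is ordered newest first.\n",
    "\n",
    "```python\n",
    "from typing import NamedTuple\n",
    "\n",
    "\n",
    "class Snapshot(NamedTuple):\n",
    "    # Hypothetical stand-in for a state snapshot; only `next` matters here.\n",
    "    step: int\n",
    "    next: tuple\n",
    "\n",
    "\n",
    "def pick_replay_point(history_newest_first):\n",
    "    # Walk oldest -> newest and keep the last snapshot waiting on '__start__'.\n",
    "    to_replay = None\n",
    "    for snap in list(history_newest_first)[::-1]:\n",
    "        if snap.next and snap.next[0] == '__start__':\n",
    "            to_replay = snap\n",
    "    return to_replay\n",
    "\n",
    "\n",
    "history = [\n",
    "    Snapshot(step=5, next=()),              # second run finished\n",
    "    Snapshot(step=4, next=('agent',)),      # second run in progress\n",
    "    Snapshot(step=3, next=('__start__',)),  # second run about to start\n",
    "    Snapshot(step=2, next=()),              # first run finished\n",
    "    Snapshot(step=1, next=('__start__',)),  # first run about to start\n",
    "]\n",
    "\n",
    "replay_point = pick_replay_point(history)\n",
    "print(replay_point.step)  # prints: 3\n",
    "```\n",
    "\n",
    "Replaying from that snapshot re-runs only the most recent question, which is the behavior wanted here."
   ]
  },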
  {
   "cell_type": "markdown",
   "id": "2a82aeef-7906-45b8-a1b0-2d2b3c18451b",
   "metadata": {},
   "source": [
    "Send the same question to the model through the new agent with VLM support:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e138cb4a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "content=\"好的，我将绘制一个饼图来展示数据集中男性和女性乘客的数量。\\n```python\\n# Count the number of passengers by gender\\ngender_counts = df['Sex'].value_counts()\\n\\n# Plot a pie chart\\nplt.figure(figsize=(8, 6))\\nplt.pie(gender_counts, labels=gender_counts.index, autopct='%1.1f%%', startangle=140)\\nplt.title('Gender Distribution')\\nplt.show()\\n```\\n\" additional_kwargs={} response_metadata={'finish_reason': 'stop', 'model_name': 'TableGPT2-7B'} id='run-2d05b2ab-32f4-481f-8fa5-43c78515d9c3'\n",
      "content=[{'type': 'text', 'text': '```pycon\\n<Figure size 800x600 with 1 Axes>\\n```'}, {'type': 'image_url', 'image_url': {'url': '...'}}] name='python' id='51a99935-b0b1-496d-9a45-c1f318104773' tool_call_id='918d57ee-7362-4e0d-8d66-64b7e57ecaf6' artifact=[]\n",
      "content='饼图显示数据集中性别分布为 50% 女性和 50% 男性，这表明男性和女性乘客数量相等。' additional_kwargs={} response_metadata={'finish_reason': 'stop', 'model_name': 'qwen2-vl-7b-instruct'} id='run-d9b0e891-f03c-40c8-8474-9fef7511c40b'\n"
     ]
    }
   ],
   "source": [
    "async for event in agent_with_vlm.astream_events(\n",
    "    None,\n",
    "    to_replay.config,\n",
    "    version=\"v2\",\n",
    "):\n",
    "    evt = event[\"event\"]\n",
    "    if evt == \"on_chat_model_end\":\n",
    "        print(event[\"data\"][\"output\"])\n",
    "    elif event[\"name\"] == \"tool_node\" and evt == \"on_chain_stream\":\n",
    "        for lc_msg in event[\"data\"][\"chunk\"][\"messages\"]:\n",
    "            print(lc_msg)\n",
    "    else:\n",
    "        # Handle other events here\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20d009cb",
   "metadata": {},
   "source": [
    "We observe that the answer from the agent with VLM support is more informative: instead of only confirming that the pie chart was generated, it describes what the chart shows."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
